import os
import sys
import gzip
from collections import defaultdict
import pybedtools
from Bio import Align
from Bio.Seq import Seq
from Bio.Seq import reverse_complement
from Bio.SeqRecord import SeqRecord
from Bio import SeqIO


from Bio.SeqIO import TwoBitIO


# RNA category to process, taken from the first command-line argument.
# NOTE(review): presumably passed on to the create_*_files functions below;
# confirm against the part of the script that dispatches on it.
target = sys.argv[1]

# Name of the genome assembly; used to build the paths of input files.
assembly = "hg38"


def read_chromosome_sizes(assembly, directory="/osc-fs_home/scratch/mdehoon/Data/Genomes/"):
    """Read the chromosome sizes of a genome assembly.

    The sizes are read from ``<directory>/<assembly>/<assembly>.chrom.sizes``,
    a whitespace-delimited text file with one ``name size`` pair per line.

    Args:
        assembly: Name of the genome assembly, e.g. "hg38".
        directory: Directory holding one subdirectory per assembly.

    Returns:
        A dictionary mapping each chromosome name to its size as an integer.

    Raises:
        ValueError: If a line does not consist of exactly two fields, or if
            the size is not an integer.
        AssertionError: If a chromosome name does not start with "chr".
    """
    filename = "%s.chrom.sizes" % assembly
    path = os.path.join(directory, assembly, filename)
    sizes = {}
    # The context manager closes the file even if a malformed line raises.
    with open(path) as handle:
        for line in handle:
            chromosome, size = line.split()
            assert chromosome.startswith("chr")
            sizes[chromosome] = int(size)
    return sizes

def read_genome(assembly):
    """Load the sequences of a genome assembly from its 2bit file.

    Returns a dictionary mapping the id of every record in the 2bit file to
    that record's sequence. Progress is reported on standard output.

    The file is opened but not closed here.
    NOTE(review): presumably the sequences are loaded lazily from the open
    file, so the handle has to outlive this function -- confirm before
    adding a close().
    """
    path = os.path.join('/osc-fs_home/scratch/mdehoon/Data/Genomes',
                        assembly,
                        '%s.2bit' % assembly)
    print("Reading genome ...", end="", flush=True)
    stream = open(path, 'rb')
    sequences = {}
    for entry in TwoBitIO.TwoBitIterator(stream):
        sequences[entry.id] = entry.seq
    print(" done")
    return sequences

def create_chrM_files():
    """Write chrM.fa and chrM.psl for the mitochondrial chromosome.

    The chrM sequence is taken from the module-level ``genome`` dictionary
    and converted to upper case. Two records are written to chrM.fa in the
    current directory: "chrM" (the sequence itself) and "chrM_reverse" (its
    reverse complement). chrM.psl receives one PSL line per record, each
    describing a single gap-free block that covers the full length of chrM:
    "chrM" on the "+" strand and "chrM_reverse" on the "-" strand.
    """
    forward = Seq(str(genome['chrM']).upper())
    reverse = forward.reverse_complement()
    records = [SeqRecord(forward, id="chrM", description=""),
               SeqRecord(reverse, id="chrM_reverse", description="")]
    filename = "chrM.fa"
    print("Writing", filename)
    with open(filename, 'w') as handle:
        written = SeqIO.write(records, handle, 'fasta')
    print("%d sequences written" % written)
    length = len(forward)
    filename = "chrM.psl"
    print("Writing", filename)
    with open(filename, 'w') as handle:
        # Both PSL lines are identical except for the strand and query name.
        for strand, qName in (("+", "chrM"), ("-", "chrM_reverse")):
            fields = (length,          # matches
                      0,               # misMatches
                      0,               # repMatches
                      0,               # nCount
                      0,               # qNumInsert
                      0,               # qBaseInsert
                      0,               # tNumInsert
                      0,               # tBaseInsert
                      strand,
                      qName,
                      length,          # qSize
                      0,               # qStart
                      length,          # qEnd
                      "chrM",          # tName
                      length,          # tSize
                      0,               # tStart
                      length,          # tEnd
                      1,               # blockCount
                      "%d," % length,  # blockSizes
                      "0,",            # qStarts
                      "0,",            # tStarts
                     )
            handle.write("\t".join([str(field) for field in fields]) + "\n")

def create_tRNA_files():
    """Write tRNA.psl and tRNA.fa from the tRNA annotation of the assembly.

    Reads tRNA.fa and tRNA.bed from the annotation directory of the
    module-level ``assembly``. Every BED entry whose chromosome name does not
    end in "_alt" is converted to one PSL line in tRNA.psl (written to the
    current directory); chromosome sizes are looked up in the module-level
    ``sizes`` dictionary. The sequences are then written unchanged to tRNA.fa
    in the current directory.

    Each sequence is expected to be exactly 3 nucleotides (the CCA tag)
    longer than the total size of its BED blocks; the CCA tag is not part of
    the genome mapping.

    Raises:
        Exception: If the BED blocks of an entry are longer than the sequence
            minus the CCA tag.
        AssertionError: If a BED entry does not have the expected layout, or
            if its blocks do not account for the full sequence minus the tag.
    """
    directory = "/osc-fs_home/mdehoon/LSA/ShortRNAPipeline/Annotation"
    path = os.path.join(directory, assembly, 'tRNA.fa')
    print("Reading", path)
    with open(path) as handle:
        records = list(SeqIO.parse(handle, 'fasta'))
    # Sequence lengths as stored in the FASTA file, i.e. including the CCA tag.
    lengths = {record.id: len(record.seq) for record in records}
    path = os.path.join(directory, assembly, "tRNA.bed")
    print("Reading", path)
    mapped = 0
    # The BED file must stay open while its lines are being iterated over.
    with open(path) as bed_handle:
        lines = pybedtools.BedTool(bed_handle)
        filename = "tRNA.psl"
        print("Writing", filename)
        with open(filename, 'w') as psl_handle:
            for line in lines:
                trna, qName = line.name.split('|')
                assert trna == 'tRNA'
                assert qName in lengths
                tName = line.chrom
                if tName.endswith("_alt"):
                    continue
                assert line.score == "."
                bed_fields = line.fields
                assert bed_fields[8] == "."
                blockSizes = bed_fields[10].rstrip(',').split(",")
                blockStarts = bed_fields[11].rstrip(',').split(",")
                blockSizes = [int(blockSize) for blockSize in blockSizes]
                blockStarts = [int(blockStart) for blockStart in blockStarts]
                blockCount = len(blockSizes)
                assert blockCount == len(blockStarts)
                tSize = sizes[tName]
                qSize = lengths[qName]
                length = sum(blockSizes)
                additional = qSize - length - 3  # 3 for CCA tag
                if additional > 0:
                    # Reported here; the assertion on qEnd below then fails.
                    print("%s: additional sequence of length %d" % (qName, additional))
                elif additional < 0:
                    raise Exception("%s: mapped length is larger than sequence length" % qName)
                tNumInsert = 0
                tBaseInsert = 0
                qStarts = []
                tStarts = []
                strand = line.strand
                qStart = 0
                qEnd = qStart
                tStart = line.start
                tEnd = tStart
                for blockSize, blockStart in zip(blockSizes, blockStarts):
                    if strand == '+':
                        qStarts.append(qEnd)
                    elif strand == '-':
                        qStarts.append(qEnd + 3)  # CCA tag; PSL convention for negative strands
                    qEnd += blockSize
                    blockStart += tStart  # BED block starts are relative to the entry start
                    tStarts.append(blockStart)
                    if tEnd > tStart:
                        # Not the first block: count the gap since the previous block.
                        tBaseInsert += blockStart - tEnd
                        tNumInsert += 1
                    tEnd = blockStart + blockSize
                assert tEnd == line.end
                assert qEnd + 3 == qSize
                blockSizes = ",".join(map(str, blockSizes)) + ","
                qStarts = ",".join(map(str, qStarts)) + ","
                tStarts = ",".join(map(str, tStarts)) + ","
                psl_fields = [length,       # matches
                              0,            # misMatches
                              0,            # repMatches
                              0,            # nCount
                              0,            # qNumInsert
                              0,            # qBaseInsert
                              tNumInsert,
                              tBaseInsert,
                              strand,
                              qName,
                              qSize,
                              qStart,
                              qEnd,
                              tName,
                              tSize,
                              tStart,
                              tEnd,
                              blockCount,
                              blockSizes,
                              qStarts,
                              tStarts]
                psl_handle.write("\t".join([str(field) for field in psl_fields]) + "\n")
                mapped += 1
    print("%d sequences unmapped" % (len(lengths) - mapped))
    filename = "tRNA.fa"
    print("Writing", filename)
    with open(filename, 'w') as handle:
        written = SeqIO.write(records, handle, 'fasta')
    print("%d sequences written" % written)

def read_refseq_mapped_sequences(accessions, records):
    """Collect the genome mapping of RefSeq transcripts from the NCBI GFF3 files.

    Only alignments to the chromosomes listed in ``chromosomes`` below are
    considered, and only for transcripts listed in ``accessions``. For each
    accession the highest version number seen is used. Transcripts that map
    to more than one genome region, have a gap between consecutive exons in
    transcript coordinates, or differ too much from the genome sequence are
    reported on standard output and left out of the result.

    If an exon has a different length in the transcript than on the genome,
    the transcript is aligned against the concatenated genome sequence of its
    exons (taken from the module-level ``genome`` dictionary). If the
    alignment score exceeds 99% of the longer of the two sequences,
    ``record.seq`` is replaced IN PLACE by the genome sequence; otherwise the
    transcript is left out of the result and its sequence is not modified.

    Args:
        accessions: Collection of "accession.version" strings to look for.
        records: Sequence of sequence records with ``id`` and ``seq``
            attributes; the ids are "accession.version" strings.

    Returns:
        A dictionary mapping "accession.version" to a dictionary that maps
        the alignment ID to a list of exons. Each exon is a tuple
        (chromosome, strand, transcript_start, transcript_end, genome_start,
        genome_end).

    Raises:
        Exception: If a Gap attribute contains an unknown operation, if no
            record is found for a mapped transcript, or if exons overlap in
            transcript coordinates.
        AssertionError: If a GFF3 line is inconsistent.
    """
    # blat scoring parameters
    match = 1
    mismatch = -1
    gap_open = -1
    gap_extend = -1 # Don't allow introns
    end_open = 0
    end_extend = 0
    aligner = Align.PairwiseAligner()
    aligner.open_gap_score = gap_open
    aligner.extend_gap_score = gap_extend
    aligner.end_open_gap_score = end_open
    aligner.end_extend_gap_score = end_extend
    # match_score / mismatch_score are the documented attribute names.
    aligner.match_score = match
    aligner.mismatch_score = mismatch
    chromosomes = {'NC_000001': 'chr1',
                   'NC_000002': 'chr2',
                   'NC_000003': 'chr3',
                   'NC_000004': 'chr4',
                   'NC_000005': 'chr5',
                   'NC_000006': 'chr6',
                   'NC_000007': 'chr7',
                   'NC_000008': 'chr8',
                   'NC_000009': 'chr9',
                   'NC_000010': 'chr10',
                   'NC_000011': 'chr11',
                   'NC_000012': 'chr12',
                   'NC_000013': 'chr13',
                   'NC_000014': 'chr14',
                   'NC_000015': 'chr15',
                   'NC_000016': 'chr16',
                   'NC_000017': 'chr17',
                   'NC_000018': 'chr18',
                   'NC_000019': 'chr19',
                   'NC_000020': 'chr20',
                   'NC_000021': 'chr21',
                   'NC_000022': 'chr22',
                   'NC_000023': 'chrX',
                   'NC_000024': 'chrY',
                   'NC_012920': 'chrM',
                  }
    exons = {}
    versions = defaultdict(list)
    directory = "/osc-fs_home/scratch/mdehoon/Data/NCBI/refseq"
    filenames = ("GCF_000001405.39_knownrefseq_alignments.gff3.gz",
                 "GCF_000001405.39_modelrefseq_alignments.gff3.gz"
                )
    # Downloaded from ftp://ftp.ncbi.nih.gov/refseq/H_sapiens/alignments/
    # on December 10, 2020.
    for filename in filenames:
        path = os.path.join(directory, filename)
        print("Reading", path)
        with gzip.open(path, 'rt') as handle:
            for line in pybedtools.BedTool(handle):
                accession_version, start, end, strand = line.attrs["Target"].split()
                assert strand=='+'
                genome_accession, version = line.chrom.split(".")
                chromosome = chromosomes.get(genome_accession)
                if chromosome is None:
                    continue
                if accession_version not in accessions:
                    continue
                alignment_ID = line.attrs['ID']
                alignments = exons.setdefault(accession_version, {})
                blocks = alignments.setdefault(alignment_ID, [])
                strand = line.strand
                genome_start = line.start
                genome_end = line.end
                transcript_start = int(start)-1
                transcript_end = int(end)
                gap_count = int(line.attrs['gap_count'])
                if gap_count == 0 or 'Gap' not in line.attrs:
                    # the gap_count is the gap count for the whole transcript;
                    # if a gap is located in an exon other than the current one,
                    # then gap_count is non-zero but Gap is not specified on this
                    # line.
                    exon = (chromosome,
                            strand,
                            transcript_start,
                            transcript_end,
                            genome_start,
                            genome_end
                           )
                    blocks.append(exon)
                else:
                    # Split the exon into its ungapped (M) segments; D skips
                    # genome sequence and I skips transcript sequence.
                    terms = line.attrs['Gap'].split()
                    assert len(terms) % 2 == 1
                    assert gap_count >= (len(terms) - 1) // 2
                    terms = [(term[0], int(term[1:])) for term in terms]
                    if strand == '+':
                        for operation, length in terms:
                            if operation == "M":
                                exon = (chromosome,
                                        strand,
                                        transcript_start,
                                        transcript_start + length,
                                        genome_start,
                                        genome_start + length
                                       )
                                blocks.append(exon)
                                transcript_start += length
                                genome_start += length
                            elif operation == "D":
                                genome_start += length
                            elif operation == "I":
                                transcript_start += length
                            else:
                                raise Exception("Unexpected gap %s" % operation)
                    elif strand == '-':
                        # On the reverse strand the transcript runs from the
                        # genome end towards the genome start.
                        for operation, length in terms:
                            if operation == "M":
                                exon = (chromosome,
                                        strand,
                                        transcript_start,
                                        transcript_start + length,
                                        genome_end-length,
                                        genome_end
                                       )
                                blocks.append(exon)
                                transcript_start += length
                                genome_end -= length
                            elif operation == "D":
                                genome_end -= length
                            elif operation == "I":
                                transcript_start += length
                            else:
                                raise Exception("Unexpected gap %s" % operation)
                    # All operations together must consume the whole exon.
                    assert transcript_end == transcript_start
                    assert genome_end == genome_start
                accession, version = accession_version.split(".")
                versions[accession].append(version)
    # Index the records once instead of scanning the list for every accession;
    # setdefault keeps the first record if an id occurs more than once.
    records_by_id = {}
    for record in records:
        records_by_id.setdefault(record.id, record)
    current_exons = {}
    for accession in versions:
        latest = max(int(version) for version in versions[accession])
        accession_version = "%s.%d" % (accession, latest)
        if accession_version not in exons:
            continue
        if len(exons[accession_version]) > 1:
            # RefSeq transcript is mapped to multiple genome regions.
            # Don't use the mapping locations, but include the RefSeq
            # transcript sequence for filtering
            print("MULTIMAPPER %s; ignoring for pslMap" % accession_version)
            continue
        record = records_by_id.get(accession_version)
        if record is None:
            raise Exception("Failed to find %s" % accession_version)
        alignment_IDs = list(exons[accession_version].keys())
        assert len(alignment_IDs) == 1
        alignment_ID = alignment_IDs[0]
        transcript_position = 0
        sequence = ""
        consistent = True
        gapped = False
        segments = []
        for exon in exons[accession_version][alignment_ID]:
            (chromosome, strand, transcript_start, transcript_end, genome_start, genome_end) = exon
            if transcript_start < transcript_position:
                raise Exception("back-splicing??")
            elif transcript_start > transcript_position:
                gap = transcript_start - transcript_position
                print("GAPPED %s (%d nucleotides gap); ignoring for pslMap" % (record.id, gap))
                gapped = True
                break
            if transcript_end - transcript_start != genome_end - genome_start:
                consistent = False
            sequence = str(genome[chromosome][genome_start:genome_end]).upper()
            segments.append(sequence)
            transcript_position = transcript_end
        if gapped:
            continue
        if strand == '-':
            # Exons are in transcript order; only each segment is flipped.
            segments = map(reverse_complement, segments)
        sequence = "".join(segments)
        if not consistent:
            # we found a discrepancy in exon length between the RefSeq
            # transcript and its genome mapping. Switch to the genome
            # sequence to keep the genome mapping locations consistent.
            score = aligner.score(str(record.seq), sequence)
            if score > 0.99 * max(len(str(record.seq)), len(sequence)):
                print("%s (%s %s): replacing RefSeq sequence by genome sequence (score = %d; RefSeq length = %d, genome length = %d)" % (chromosome, strand, record.id, score, len(record.seq), len(sequence)))
            else:
                print("%s (%s %s): discrepancy between RefSeq sequence and genome sequence (score = %d; RefSeq length = %d, genome length = %d); ignoring for pslMap" % (chromosome, strand, record.id, score, len(record.seq), len(sequence)))
                # Actually ignore the mapping, as the message says, and keep
                # the RefSeq sequence as it is.
                continue
            record.seq = Seq(sequence)
        current_exons[accession_version] = exons[accession_version]
    print("REMAINING: %d out of %d" % (len(current_exons), len(versions)))
    return current_exons

def read_refseq_rRNA_sequences(directory="/osc-fs_home/scratch/mdehoon/Data/NCBI/refseq"):
    """Read the human ribosomal RNA sequences from the RefSeq FASTA files.

    The files named ``human.<number>.rna.fna.gz`` in ``directory`` are read
    in increasing order of ``<number>``. A record is kept if its description
    (the text after the accession) consists of three ", "-separated terms, of
    which the first is "Homo sapiens RNA" and the last is "ribosomal RNA".

    Args:
        directory: Directory containing the gzipped RefSeq FASTA files.

    Returns:
        A list of the selected sequence records, in the order in which they
        were read.

    Raises:
        ValueError: If ``<number>`` in a file name is not an integer, or if a
            record has no description after the accession.
    """
    filenames = []
    for filename in os.listdir(directory):
        terms = filename.split('.')
        if len(terms) != 5:
            continue
        if (terms[0], terms[2], terms[3], terms[4]) != ('human', 'rna', 'fna', 'gz'):
            continue
        filenames.append(filename)
    filenames.sort(key=lambda filename: int(filename.split('.')[1]))
    records = []
    for filename in filenames:
        path = os.path.join(directory, filename)
        print("Reading", path)
        with gzip.open(path, 'rt') as handle:
            for record in SeqIO.parse(handle, 'fasta'):
                accession, description = record.description.split(None, 1)
                terms = description.split(", ", 2)
                # A description with fewer than three terms cannot be an rRNA
                # description; checking the length avoids an IndexError.
                if len(terms) != 3:
                    continue
                if terms[0] == 'Homo sapiens RNA' and terms[2] == 'ribosomal RNA':
                    records.append(record)
    return records

def create_rRNA_files():
    """Write rRNA.psl and rRNA.fa for the RefSeq ribosomal RNA transcripts.

    The rRNA records are read with read_refseq_rRNA_sequences and their
    genome mapping with read_refseq_mapped_sequences. The PSL lines produced
    by write_exon_locations are sorted by genome location and written to
    rRNA.psl; all records, mapped or not, are written to rRNA.fa. Both files
    are created in the current directory.
    """
    records = read_refseq_rRNA_sequences()
    # NOTE(review): the lengths are taken before read_refseq_mapped_sequences
    # runs, and that function may replace record.seq -- confirm that the
    # original lengths are the ones wanted in the PSL file.
    lengths = {record.id: len(record.seq) for record in records}
    accessions = lengths.keys()
    exons = read_refseq_mapped_sequences(accessions, records)
    lines = sorted(write_exon_locations(exons, lengths), key=get_genome_location)
    filename = "rRNA.psl"
    print("Writing", filename)
    with open(filename, 'w') as handle:
        for line in lines:
            handle.write("\t".join([str(word) for word in line]) + "\n")
    missing = sum(1 for record in records if record.id not in exons)
    print("%d sequences unmapped" % missing)
    filename = "rRNA.fa"
    print("Writing", filename)
    with open(filename, 'w') as handle:
        written = SeqIO.write(records, handle, 'fasta')
    print("%d sequences written" % written)

def check_description_histone(description):
    """Return True if a RefSeq description denotes a human histone mRNA.

    An optional leading "PREDICTED: " is ignored. The description must end in
    ", mRNA" and start with "Homo sapiens "; what follows "Homo sapiens " must
    either contain "linker histone", "clustered histone" or "variant histone",
    or start with one of the histone names listed below.

    Args:
        description: The description of a RefSeq transcript, without the
            accession.

    Returns:
        True for a histone mRNA description and False otherwise, including
        for a description that does not contain ", " at all.
    """
    prefix = "PREDICTED: "
    if description.startswith(prefix):
        description = description[len(prefix):]
    # rpartition yields an empty separator if ", " is absent, in which case
    # the description cannot end in ", mRNA".
    description, separator, term = description.rpartition(", ")
    if not separator or term != 'mRNA':
        return False
    prefix = "Homo sapiens "
    if not description.startswith(prefix):
        return False
    description = description[len(prefix):]
    keywords = ('linker histone', 'clustered histone', 'variant histone')
    if any(keyword in description for keyword in keywords):
        return True
    prefixes = ("H2A.J histone",
                "H2A.P histone",
                "H2A.W histone",
                "H2B.S histone",
                "H2B.U histone",
                "H2B.W histone",
                "H3.2 histone",
                "H3.3 histone",
                "H3.4 histone",
                "H3.5 histone",
                "H3.Y histone",
                "H4 histone",
                "histone H2B",
                "macroH2A.1 histone",
                "macroH2A.2 histone",
               )
    return description.startswith(prefixes)

def parse_description(description):
    """Classify a RefSeq transcript description into an RNA category.

    The checks are applied in a fixed order, and the first one that matches
    determines the category. Most branches also validate the layout of the
    description with assertions, so that an unexpected description format is
    noticed instead of being silently misclassified.

    Args:
        description: The description of a RefSeq transcript, without the
            accession.

    Returns:
        One of 'scRNA', 'snRNA', 'vRNA', 'snhg', 'snoRNA', 'scaRNA', 'yRNA',
        'histone', 'RMRP', 'RPPH', 'TERC', 'MALAT1', 'mRNA', 'lncRNA',
        'microRNA', 'rRNA' or 'snar'.

    Raises:
        Exception: If the description does not match any category.
        AssertionError, ValueError: If the description matches a category but
            does not have the layout expected for that category.
    """
    # 7SL RNA, older description format.
    if description.startswith("Homo sapiens RNA, 7SL, cytoplasmic"):
        terms = description.split(", ", 2)
        assert terms[0]=='Homo sapiens RNA'
        assert terms[1]=='7SL'
        term = terms[2]
        terms = term.rsplit(", ", 1)
        i = terms[0].find('(')
        j = terms[0].find(')', i)
        name = terms[0][i+1:j]
        term = terms[0][:i-1]
        if term.endswith(", pseudogene"):
            assert terms[1]=='non-coding RNA'
            term = term[:-len(", pseudogene")]
        else:
            assert terms[1]=='small cytoplasmic RNA'
        cytoplasmic, number = term.split()
        assert cytoplasmic=='cytoplasmic'
        number = int(number)  # validation only: must be an integer
        return 'scRNA'
    # 7SL RNA, newer description format.
    if description.startswith("Homo sapiens RNA component of signal recognition particle"):
        terms = description.split(", ")
        assert terms[1] == 'small cytoplasmic RNA'
        term = terms[0]
        terms = term.rsplit(None, 2)
        assert terms[0] == "Homo sapiens RNA component of signal recognition particle"
        term = terms[2]
        assert term.startswith('(')
        assert term.endswith(')')
        name = term[1:-1]
        term = terms[1]
        assert term in ('7SL1', '7SL2', '7SL3')
        assert name == 'RN' + term
        return 'scRNA'
    if description == "Homo sapiens RNA component of 7SK nuclear ribonucleoprotein (RN7SK), small nuclear RNA":
        return 'snRNA'
    else:
        terms = description.split(", ", 2)
        # Only a three-term description can be a small nuclear RNA; checking
        # the length avoids an IndexError on "Homo sapiens RNA, <one term>".
        if len(terms) == 3 and terms[0] == 'Homo sapiens RNA' and terms[2] == 'small nuclear RNA':
            return 'snRNA'
    if description.startswith("Homo sapiens vault RNA"):
        terms = description.split(", ")
        assert len(terms)==2
        description, number, name = terms[0].rsplit(None, 2)
        assert description=="Homo sapiens vault RNA"
        assert name.startswith("(")
        assert name.endswith(")")
        name = name[1:-1]
        assert name.startswith("VTRNA1") or name.startswith("VTRNA2")
        return 'vRNA'
    # Host genes must be checked before the small nucleolar RNAs themselves,
    # as their descriptions also contain "small nucleolar RNA".
    if "small nucleolar RNA host gene" in description:
        terms = description.split(", ")
        assert terms[-1] == 'long non-coding RNA'
        terms = terms[:-1]
        if description.startswith("Homo sapiens small nucleolar RNA host gene"):
            if terms[-1].startswith("transcript variant"):
                words = terms[-1].rsplit(None, 1)
                assert len(words) == 2
                assert words[0] == "transcript variant"
                number = int(words[1])  # validation only
                terms = terms[:-1]
        else:
            description, gene = terms[-1].rsplit(None, 1)
            assert gene.startswith("(")
            assert gene.endswith(")")
            gene = gene[1:-1]
            if gene == 'MEG8':
                assert terms[0] == "Homo sapiens maternally expressed 8"
                assert description == "small nucleolar RNA host gene"
            elif gene.startswith("SNHG"):
                description, number = description.rsplit(None, 1)
                assert description == "Homo sapiens small nucleolar RNA host gene"
                number = int(number)  # validation only
        return 'snhg'
    if description.startswith("Homo sapiens small nucleolar RNA"):
        terms = description.split(", ")
        assert len(terms)==3
        assert terms[0]=="Homo sapiens small nucleolar RNA"
        assert terms[2]=="small nucleolar RNA"
        description, name = terms[1].rsplit(None, 1)
        assert name.startswith("(")
        assert name.endswith(")")
        name = name[1:-1]
        if '(pseudogene)' in description:
            description, pseudogene = description.rsplit(None, 1)
            assert pseudogene == '(pseudogene)'
        elif 'pseudogene' in description:
            description, pseudogene, paralog = description.rsplit(None, 2)
            assert pseudogene == 'pseudogene'
            paralog = int(paralog)  # validation only
        category, number = description.rsplit(None, 1)
        if category == "H/ACA box":
            assert name.startswith("SNORA")
        elif category ==  "C/D box":
            assert name.startswith("SNORD")
        else:
            raise Exception("Unknown category %s" % category)
        # The number may carry a paralog suffix ("-2"), a trailing letter
        # and/or a leading "U"; what remains must be an integer.
        if '-' in number:
            number, paralog = number.split('-')
            paralog = int(paralog)
            assert paralog > 0
        if number[-1] in 'ABCDEFGHIJK':
            number = number[:-1]
        if number[0] == 'U':
            number = number[1:]
        number = int(number)  # validation only
        return 'snoRNA'
    if description.startswith("Homo sapiens RNA"):
        try:
            description, localization = description.split("; ")
            assert localization == "nuclear gene for mitochondrial product"
        except ValueError:
            # No single "; " separator in the description; use it as it is.
            pass
        terms = description.split(", ")
        if len(terms) != 3:
            assert 'small nucleolar RNA' not in description
        elif terms[2] in ('mRNA', 'non-coding RNA', 'long non-coding RNA'):
            assert 'small nucleolar RNA' not in description
        else:
            assert terms[0] in ("Homo sapiens RNA",
                               ), description
            if terms[2]!='small nucleolar RNA':
                assert 'small nucleolar RNA' not in description
            else:
                description, name = terms[1].rsplit(None, 1)
                number, category = description.split(None, 1)
                assert category=='small nucleolar'
                assert name.startswith('(')
                assert name.endswith(')')
                name = name[1:-1]
                assert name=="RN%s" % number
                return 'snoRNA'
    else:
        # Make sure that no small nucleolar RNA slipped through the checks above.
        assert ('small nucleolar RNA binding protein' in description
             or 'small nucleolar RNA host gene' in description
             or ('small nucleolar RNA' not in description
             and 'snoRNA' not in description))
    prefix = "Homo sapiens small Cajal body-specific RNA "
    if description.startswith(prefix):
        terms = description.split(", ")
        assert terms[1]=='guide RNA'
        return 'scaRNA'
    else:
        assert "Cajal body-specific RNA" not in description
    terms = description.rsplit(", ", 1)
    if len(terms) != 2:
        # Without a ", " there is no final term to classify; report this the
        # same way as any other unparseable description.
        raise Exception("Failed to parse %s" % description)
    term = terms[1]
    if description.startswith("Homo sapiens RNA, Ro60-associated Y"):
        assert term == 'Y RNA'
        return 'yRNA'
    else:
        assert term != 'Y RNA'
    if check_description_histone(description):
        return 'histone'
    if term=='RNase MRP RNA':
        return 'RMRP'
    if term=='RNase P RNA':
        return 'RPPH'
    if term=='telomerase RNA':
        return 'TERC'
    term = terms[0]
    if term == "Homo sapiens brain cytoplasmic RNA 1 (BCYRN1)":
        return 'scRNA'
    terms = description.rsplit(", ", 2)
    term = terms[-1]
    if term == 'long non-coding RNA' and terms[0] == 'Homo sapiens metastasis associated lung adenocarcinoma transcript 1 (MALAT1)':
        term = terms[1]
        assert term in ('transcript variant 1',
                        'transcript variant 2',
                        'transcript variant 3')
        return "MALAT1"
    # From here on the category is decided by the final term only.
    terms = description.rsplit(", ", 1)
    term = terms[1]
    if term in ('mRNA',
                'mRNA; nuclear gene for mitochondrial product',
                'mRNA; nuclear gene for mitochondrial products',
                'mRNA; nuclear genes for mitochondrial products',
                'partial mRNA',
               ):
        return 'mRNA'
    if term in ('long non-coding RNA',
                'non-coding RNA',
                'ncRNA',
                'antisense RNA',
                'misc_RNA',
               ):
        return 'lncRNA'
    if term == "microRNA":
        return "microRNA"
    if term == "ribosomal RNA":
        return "rRNA"
    if term == "small cytoplasmic RNA":
        return "scRNA"
    prefix = "Homo sapiens small NF90 (ILF3) associated RNA"
    if description.startswith(prefix):
        assert term == 'small nuclear RNA'
        return 'snar'
    raise Exception("Failed to parse %s" % description)

def parse_refseq_fasta(target, directory="/osc-fs_home/scratch/mdehoon/Data/NCBI/refseq"):
    """Read the RefSeq transcripts that belong to one RNA category.

    The files named ``human.<number>.rna.fna.gz`` in ``directory`` are read
    in increasing order of ``<number>``. A record is kept if
    parse_description returns ``target`` for its description (the text after
    the accession).

    Args:
        target: The RNA category to select, as returned by parse_description.
        directory: Directory containing the gzipped RefSeq FASTA files.

    Returns:
        A list of the selected sequence records, in the order in which they
        were read.

    Raises:
        ValueError: If ``<number>`` in a file name is not an integer, or if a
            record has no description after the accession.
        Exception: Any exception raised by parse_description for a
            description that it cannot classify.
    """
    filenames = []
    for filename in os.listdir(directory):
        terms = filename.split('.')
        if len(terms) != 5:
            continue
        if (terms[0], terms[2], terms[3], terms[4]) != ('human', 'rna', 'fna', 'gz'):
            continue
        filenames.append(filename)
    filenames.sort(key=lambda filename: int(filename.split('.')[1]))
    records = []
    for filename in filenames:
        path = os.path.join(directory, filename)
        print("Reading", path)
        # The context manager closes the file even if a description cannot
        # be parsed.
        with gzip.open(path, 'rt') as handle:
            for record in SeqIO.parse(handle, 'fasta'):
                accession, description = record.description.split(None, 1)
                if parse_description(description) == target:
                    records.append(record)
    return records
        
def write_exon_locations(exons, lengths, chromosome_sizes=None):
    """Generate the PSL fields describing the genome mapping of transcripts.

    One list of 21 PSL fields is generated for each alignment of each
    accession that occurs in both ``lengths`` and ``exons``. An alignment is
    skipped, with a message on standard output, if its exons do not cover the
    transcript contiguously starting at position 0. An alignment without any
    exons is skipped silently. Transcript sequence beyond the last exon is
    reported as a poly(A) tail and is not part of the alignment.

    Args:
        exons: Dictionary mapping an accession to a dictionary that maps an
            alignment ID to a list of exons. Each exon is a tuple
            (chromosome, strand, transcript_start, transcript_end,
            genome_start, genome_end).
        lengths: Dictionary mapping an accession to its sequence length.
        chromosome_sizes: Dictionary mapping a chromosome name to its size.
            If None, the module-level ``sizes`` dictionary is used.

    Yields:
        Lists [matches, misMatches, repMatches, nCount, qNumInsert,
        qBaseInsert, tNumInsert, tBaseInsert, strand, qName, qSize, qStart,
        qEnd, tName, tSize, tStart, tEnd, blockCount, blockSizes, qStarts,
        tStarts], with the last three as comma-terminated strings.

    Raises:
        Exception: If the exons extend beyond the end of the sequence.
        AssertionError: If the exons of one alignment are on different
            chromosomes or strands.
    """
    for accession in lengths:
        if accession not in exons:
            continue
        for alignment_ID in exons[accession]:
            position = 0
            chromosome = None
            strand = None
            blocks = []
            for exon in sorted(exons[accession][alignment_ID]):
                exon_chromosome, exon_strand, transcript_start, transcript_end, genome_start, genome_end = exon
                if chromosome is None:
                    chromosome = exon_chromosome
                    strand = exon_strand
                else:
                    assert chromosome == exon_chromosome
                    assert strand == exon_strand
                if position != transcript_start:
                    print("%s: transcript_start = %d, position = %d; SKIPPING" % (accession, transcript_start, position))
                    break
                position = transcript_end
                blocks.append((genome_start, genome_end))
            else: # no errors
                if not blocks:
                    # Nothing is mapped, so there is no PSL line to write.
                    continue
                qName = accession
                qSize = lengths[accession]
                polyAtail = qSize - position
                if polyAtail > 0:
                    print('%s: Poly(A)-tail length %d' % (accession, polyAtail))
                if polyAtail < 0:
                    raise Exception("%s: sequence shorter than mapped length" % accession)
                if strand=='-':
                    # PSL blocks are listed in increasing genome order.
                    blocks = blocks[::-1]
                tName = chromosome
                if chromosome_sizes is None:
                    tSize = sizes[chromosome]
                else:
                    tSize = chromosome_sizes[chromosome]
                blockCount = len(blocks)
                blockSizes = []
                qStarts = []
                qStart = 0
                qEnd = 0
                tStarts = []
                tEnd = None
                tNumInsert = 0
                tBaseInsert = 0
                for start, end in blocks:
                    blockSize = end - start
                    blockSizes.append(blockSize)
                    if strand == '+':
                        qStarts.append(qEnd)
                    elif strand == '-':
                        qStarts.append(qEnd + polyAtail)  # PSL convention for negative strands
                    qEnd += blockSize
                    tStarts.append(start)
                    if tEnd is not None:
                        # Not the first block: count the gap since the previous block.
                        tBaseInsert += start - tEnd
                        tNumInsert += 1
                    tEnd = end
                length = sum(blockSizes)
                tStart = tStarts[0]
                blockSizes = ",".join(map(str, blockSizes)) + ","
                qStarts = ",".join(map(str, qStarts)) + ","
                tStarts = ",".join(map(str, tStarts)) + ","
                fields = [length,       # matches
                          0,            # misMatches
                          0,            # repMatches
                          0,            # nCount
                          0,            # qNumInsert
                          0,            # qBaseInsert
                          tNumInsert,
                          tBaseInsert,
                          strand,
                          qName,
                          qSize,
                          qStart,
                          qEnd,
                          tName,
                          tSize,
                          tStart,
                          tEnd,
                          blockCount,
                          blockSizes,
                          qStarts,
                          tStarts]
                yield fields

def get_genome_location(fields):
    """Return the genomic sort key (tName, tStart, tEnd, strand) of a PSL row.

    `fields` must be an iterable of exactly 21 items in PSL column order;
    any other length makes the unpacking below raise ValueError.
    """
    # Only the strand (9th column) and the target name, start and end
    # (14th, 16th and 17th columns) are used; all other columns are dropped.
    (_, _, _, _, _, _, _, _,
     strand,
     _, _, _, _,
     tName,
     _,
     tStart, tEnd,
     _, _, _, _) = fields
    return (tName, tStart, tEnd, strand)

def create_refseq_files(target):
    """Write `<target>.psl` and `<target>.fa` for the RefSeq set `target`.

    The sequence records come from parse_refseq_fasta(target) and their
    genome mappings from read_refseq_mapped_sequences().  The rows produced
    by write_exon_locations() are sorted by genomic location
    (see get_genome_location) and written tab-separated to the PSL file;
    all records are then written to the FASTA file.  The number of records
    without a mapping and the number of sequences written are printed.
    """
    records = parse_refseq_fasta(target)
    # NOTE(review): `records` is iterated several times below (and handed to
    # read_refseq_mapped_sequences), so parse_refseq_fasta is assumed to
    # return a re-iterable collection such as a list -- confirm.
    lengths = {record.id: len(record.seq) for record in records}
    accessions = lengths.keys()
    exons = read_refseq_mapped_sequences(accessions, records)
    lines = list(write_exon_locations(exons, lengths))
    lines.sort(key=get_genome_location)
    filename = "%s.psl" % target
    print("Writing", filename)
    # Context managers make sure the files are closed even if a write fails.
    with open(filename, 'w') as handle:
        for line in lines:
            handle.write("\t".join([str(word) for word in line]) + "\n")
    missing = sum(1 for record in records if record.id not in exons)
    print("%d sequences unmapped" % missing)
    filename = "%s.fa" % target
    print("Writing", filename)
    with open(filename, 'w') as handle:
        written = SeqIO.write(records, handle, 'fasta')
    print("%d sequences written" % written)

def create_gencode_files(target):
    """Write `<target>.fa` and `<target>.psl` for the GENCODE v34 transcripts.

    The transcript sequences are read from gencode.v34.transcripts.fa.gz;
    each record id is reduced to the part before the first '|' and the
    records are written to `<target>.fa`.  The exon features in
    gencode.v34.annotation.gtf.gz are then grouped by transcript_id, and one
    PSL line per transcript (ordered by transcript id) is written to
    `<target>.psl`.  Target sizes are taken from the module-level `sizes`.

    Raises AssertionError if an exon refers to a transcript without a
    sequence, if the exons of a transcript disagree on chromosome or strand,
    or if a transcript with a sequence has no exon.
    """
    directory = "/osc-fs_home/scratch/mdehoon/Data/Gencode"
    lengths = {}

    def find_records(path):
        # Stream the transcript records, trimming each id to its first
        # '|'-separated term and remembering the sequence length.
        with gzip.open(path, 'rt') as handle:
            for record in SeqIO.parse(handle, 'fasta'):
                accession = record.id.split('|')[0]
                record.id = accession
                lengths[accession] = len(record.seq)
                yield record

    path = os.path.join(directory, 'gencode.v34.transcripts.fa.gz')
    filename = "%s.fa" % target
    print("Writing", filename)
    with open(filename, 'w') as handle:
        written = SeqIO.write(find_records(path), handle, 'fasta')
    print("%d sequences written" % written)
    # `lengths` is filled as a side effect of writing the FASTA file above.
    accessions = lengths.keys()
    chromosomes = {}
    strands = {}
    exons = {}
    path = os.path.join(directory, "gencode.v34.annotation.gtf.gz")
    print("Reading", path)
    with gzip.open(path, 'rt') as handle:
        for line in pybedtools.BedTool(handle):
            if line.fields[2] != 'exon':
                continue
            accession = line.attrs['transcript_id']
            assert accession in accessions
            chromosome = line.chrom
            strand = line.strand
            # NOTE(review): presumably pybedtools reports 0-based,
            # half-open coordinates for GTF features -- confirm.
            exon = (line.start, line.end)
            if chromosomes.get(accession) is None:
                chromosomes[accession] = chromosome
                strands[accession] = strand
                exons[accession] = []
            else:
                assert chromosomes[accession] == chromosome
                assert strands[accession] == strand
            exons[accession].append(exon)
    # Check if all transcripts were mapped
    for accession in accessions:
        assert accession in exons
    filename = "%s.psl" % target
    print("Writing %d sequences to %s" % (len(accessions), filename))
    with open(filename, 'w') as handle:
        for accession in sorted(chromosomes):
            blocks = sorted(exons[accession])
            tName = chromosomes[accession]
            strand = strands[accession]
            tSize = sizes[tName]
            qStart = 0
            qEnd = 0
            qStarts = []
            tStarts = []
            blockSizes = []
            tEnd = None
            tNumInsert = 0
            tBaseInsert = 0
            for blockStart, blockEnd in blocks:
                # A gap between consecutive exons (an intron) counts as an
                # insert in the target in PSL terms.
                if tEnd is not None and blockStart > tEnd:
                    tNumInsert += 1
                    tBaseInsert += blockStart - tEnd
                blockSize = blockEnd - blockStart
                qStarts.append(qEnd)
                tStarts.append(blockStart)
                blockSizes.append(blockSize)
                qEnd += blockSize
                tEnd = blockEnd
            tStart = tStarts[0]
            # The whole transcript is aligned, so the query size equals the
            # total exon length.
            qSize = qEnd - qStart
            fields = [qSize,        # matches
                      0,            # misMatches
                      0,            # repMatches
                      0,            # nCount
                      0,            # qNumInsert
                      0,            # qBaseInsert
                      tNumInsert,   # tNumInsert
                      tBaseInsert,  # tBaseInsert
                      strand,       # strand
                      accession,    # qName
                      qSize,        # qSize
                      qStart,       # qStart
                      qSize,        # qEnd
                      tName,        # tName
                      tSize,        # tSize
                      tStart,       # tStart
                      tEnd,         # tEnd
                      len(blocks),  # blockCount
                      ",".join(map(str, blockSizes)) + ",",  # blockSizes
                      ",".join(map(str, qStarts)) + ",",     # qStarts
                      ",".join(map(str, tStarts)) + ","]     # tStarts
            handle.write("\t".join([str(field) for field in fields]) + "\n")

def create_fantomcat_files(target):
    """Write fantomcat.psl and fantomcat.fa from the FANTOM CAT transcript BED file.

    For every line of F6_CAT.transcript.bed.gz the block sequences are
    extracted from the module-level `genome`, concatenated, upper-cased and,
    for '-' strand transcripts, reverse-complemented.  One PSL line and one
    FASTA record (named after the BED name field) are written per
    transcript.  Target sizes are taken from the module-level `sizes`.

    `target` is accepted for symmetry with the other create_*_files
    functions but is not used; the output file names are fixed.

    Raises Exception if a strand is neither '+' nor '-', or if a transcript
    yields an empty sequence.
    """
    directory = "/osc-fs_home/mdehoon/Data/Fantom6/FANTOMCAT"
    path = os.path.join(directory, "F6_CAT.transcript.bed.gz")

    def generate_records(lines, psl_handle):
        # Yield one SeqRecord per BED line, writing the matching PSL line
        # to psl_handle just before each record is yielded.
        for line in lines:
            chromosome = line.chrom
            strand = line.strand
            qName = line.name
            tName = chromosome
            tSize = sizes[tName]
            blockSizes = [int(word) for word in line.fields[10].rstrip(",").split(",")]
            blockStarts = [int(word) for word in line.fields[11].rstrip(",").split(",")]
            qStarts = []
            tStarts = []
            pieces = []
            qSize = 0
            tNumInsert = 0
            tBaseInsert = 0
            previousEnd = None
            for blockSize, blockStart in zip(blockSizes, blockStarts):
                # Block starts in the BED line are relative to line.start.
                tStart = line.start + blockStart
                # A gap between consecutive blocks counts as an insert in
                # the target in PSL terms.
                if previousEnd is not None and tStart > previousEnd:
                    tNumInsert += 1
                    tBaseInsert += tStart - previousEnd
                qStarts.append(qSize)
                tStarts.append(tStart)
                pieces.append(str(genome[chromosome][tStart:tStart+blockSize]))
                qSize += blockSize
                previousEnd = tStart + blockSize
            fields = [qSize,          # matches
                      0,              # misMatches
                      0,              # repMatches
                      0,              # nCount
                      0,              # qNumInsert
                      0,              # qBaseInsert
                      tNumInsert,     # tNumInsert
                      tBaseInsert,    # tBaseInsert
                      strand,         # strand
                      qName,          # qName
                      qSize,          # qSize
                      0,              # qStart
                      qSize,          # qEnd
                      tName,          # tName
                      tSize,          # tSize
                      line.start,     # tStart
                      line.end,       # tEnd
                      len(tStarts),   # blockCount
                      ",".join(map(str, blockSizes)) + ",",  # blockSizes
                      ",".join(map(str, qStarts)) + ",",     # qStarts
                      ",".join(map(str, tStarts)) + ","]     # tStarts
            sequence = "".join(pieces).upper()
            if strand == '+':
                pass
            elif strand == '-':
                sequence = reverse_complement(sequence)
            else:
                raise Exception("Unexpected strand")
            if len(sequence) == 0:
                print(line.name)
                print(str(line))
                raise Exception("Empty sequence")
            # An explicit empty description keeps Biopython from appending
            # its default "<unknown description>" to the FASTA header.
            record = SeqRecord(Seq(sequence), id=line.name, description="")
            psl_handle.write("\t".join([str(field) for field in fields]) + "\n")
            yield record

    print("Reading", path)
    # The input must stay open while the records are generated lazily, so
    # all three files are managed by nested context managers.
    with gzip.open(path, 'rt') as handle:
        lines = pybedtools.BedTool(handle)
        filename = 'fantomcat.psl'
        print("Writing", filename)
        with open(filename, 'wt') as psl_handle:
            filename = 'fantomcat.fa'
            print("Writing", filename)
            with open(filename, 'wt') as fasta_handle:
                records = generate_records(lines, psl_handle)
                SeqIO.write(records, fasta_handle, 'fasta')

def create_novel_files(target):
    """Write novel.bed and novel.fa from transcripts.novel.curated.bed.

    Each retained line of the input BED file is renamed to novel_<N>
    (N counting from 1), gets its score and itemRgb fields reset to "0",
    and is written to novel.bed.  Its block sequences are extracted from
    the module-level `genome`, concatenated, upper-cased and, for '-'
    strand transcripts, reverse-complemented before being written to
    novel.fa.  A line is skipped when its blockSizes and blockStarts fields
    equal those of the previously retained line.

    `target` is accepted for symmetry with the other create_*_files
    functions but is not used; the file names are fixed.

    Raises Exception if a strand is neither '+' nor '-', or if a transcript
    yields an empty sequence.
    """
    def generate_records(lines, bed_handle):
        # Yield one SeqRecord per retained BED line, writing the renamed
        # line to bed_handle along the way.
        previous = None
        counter = 0
        for line in lines:
            fields = list(line.fields)
            fields[4] = "0" # score
            fields[8] = "0" # itemRgb
            # NOTE(review): duplicates are detected from the blockSizes and
            # blockStarts fields only (block starts are relative offsets);
            # chromosome, start and strand are not compared -- confirm this
            # is sufficient for the curated input.
            exons = fields[10:12]
            if exons == previous:
                continue
            previous = exons
            counter += 1
            name = "novel_%d" % counter
            fields[3] = name
            interval = pybedtools.create_interval_from_list(fields)
            bed_handle.write(str(interval))
            chromosome = line.chrom
            chromStart = line.start
            blockSizes = line.fields[10].rstrip(",").split(",")
            blockStarts = line.fields[11].rstrip(",").split(",")
            pieces = []
            for blockSize, blockStart in zip(blockSizes, blockStarts):
                blockSize = int(blockSize)
                blockStart = int(blockStart) + chromStart
                pieces.append(str(genome[chromosome][blockStart:blockStart+blockSize]))
            sequence = "".join(pieces).upper()
            if line.strand == '+':
                pass
            elif line.strand == '-':
                sequence = reverse_complement(sequence)
            else:
                raise Exception("Unexpected strand")
            if len(sequence) == 0:
                print(line.name)
                print(str(line))
                raise Exception("Empty sequence")
            yield SeqRecord(Seq(sequence), id=name, description="")

    filename = "transcripts.novel.curated.bed"
    print("Reading", filename)
    # The input must stay open while the records are generated lazily, so
    # all three files are managed by nested context managers.
    with open(filename) as handle:
        lines = pybedtools.BedTool(handle)
        filename = 'novel.bed'
        print("Writing", filename)
        with open(filename, 'wt') as bed_handle:
            filename = 'novel.fa'
            print("Writing", filename)
            with open(filename, 'wt') as fasta_handle:
                records = generate_records(lines, bed_handle)
                SeqIO.write(records, fasta_handle, 'fasta')

# Load the genome sequences and chromosome sizes used by the
# create_*_files functions.
genome = read_genome(assembly)
sizes = read_chromosome_sizes(assembly)

# Builders for the targets that have a dedicated function.  The lambdas
# postpone the name lookup until the selected builder is actually called.
builders = {
    'chrM': lambda: create_chrM_files(),
    'rRNA': lambda: create_rRNA_files(),
    'tRNA': lambda: create_tRNA_files(),
    'gencode': lambda: create_gencode_files(target),
    'fantomcat': lambda: create_fantomcat_files(target),
    'novel': lambda: create_novel_files(target),
}

build = builders.get(target)
if build is None:
    # Any other target is handled as a RefSeq set.
    create_refseq_files(target)
else:
    build()
